{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bec6975b",
      "metadata": {
        "id": "bec6975b"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import ast\n",
        "import matplotlib.pyplot as plt\n",
        "import neurokit2 as nk\n",
        "import json\n",
        "import os"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "lClpN762IhKm",
      "metadata": {
        "id": "lClpN762IhKm"
      },
      "outputs": [],
      "source": [
        "df = pd.read_pickle('/path_to_heartbeats.pkl')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "wnuAlmWnF4e0",
      "metadata": {
        "id": "wnuAlmWnF4e0"
      },
      "outputs": [],
      "source": [
        "# Build a class-balanced subset: the same number of heartbeats per label.\n",
        "NUM_CLASSES = 4\n",
        "SAMPLES_PER_CLASS = 5000\n",
        "SAMPLING_SEED = 42  # fixed seed so Restart & Run All draws the same rows\n",
        "\n",
        "# Sample each class, then concatenate once; growing the frame with pd.concat\n",
        "# inside a loop re-copies everything collected so far on every iteration.\n",
        "total_db_new = pd.concat(\n",
        "    [\n",
        "        df[df['label'] == class_id].sample(n=SAMPLES_PER_CLASS, random_state=SAMPLING_SEED)\n",
        "        for class_id in range(NUM_CLASSES)\n",
        "    ],\n",
        "    ignore_index=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0d576bd3",
      "metadata": {
        "id": "0d576bd3"
      },
      "outputs": [],
      "source": [
        "def resample_signal(row):\n",
        "    \"\"\"Resample one row's signal from a 250 to a 360 sampling rate (FFT method).\n",
        "\n",
        "    Intended for ``DataFrame.apply(..., axis=1)``: the returned Series has\n",
        "    exactly two fields, the resampled ``signal`` and the unchanged ``label``.\n",
        "    \"\"\"\n",
        "    upsampled = nk.signal_resample(\n",
        "        row['signal'],\n",
        "        sampling_rate=250,\n",
        "        desired_sampling_rate=360,\n",
        "        method=\"FFT\",\n",
        "    )\n",
        "    return pd.Series({'signal': upsampled, 'label': row['label']})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "iMAWHE6QLPPl",
      "metadata": {
        "id": "iMAWHE6QLPPl"
      },
      "outputs": [],
      "source": [
        "# Resample every heartbeat, then point `total_db_new` at the resampled frame.\n",
        "# NOTE: because `total_db_new` is rebound here, running this cell a second time\n",
        "# would resample the already-resampled signals again.\n",
        "resampled_df = total_db_new.apply(resample_signal, axis=1)\n",
        "total_db_new = resampled_df"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "g9fqfIyXE6CB",
      "metadata": {
        "id": "g9fqfIyXE6CB"
      },
      "source": [
        "# Split and Normalization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ymYewx2pE8E7",
      "metadata": {
        "id": "ymYewx2pE8E7"
      },
      "outputs": [],
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# 70 % train, then the remaining 30 % is divided 1/3 validation and 2/3 test\n",
        "# (10 % / 20 % of the data). Both splits are stratified on the label so each\n",
        "# subset keeps the class balance of the sampled data.\n",
        "train_df, temp_df = train_test_split(\n",
        "    total_db_new, test_size=0.3, random_state=42, stratify=total_db_new['label']\n",
        ")\n",
        "val_df, test_df = train_test_split(\n",
        "    temp_df, test_size=(2/3), random_state=42, stratify=temp_df['label']\n",
        ")\n",
        "\n",
        "# reset_index(drop=True) returns a new frame with a 0..n-1 index, so the\n",
        "# column assignments made later operate on independent frames.\n",
        "train_df = train_df.reset_index(drop=True)\n",
        "val_df = val_df.reset_index(drop=True)\n",
        "test_df = test_df.reset_index(drop=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6Q_TwZaddElo",
      "metadata": {
        "id": "6Q_TwZaddElo"
      },
      "outputs": [],
      "source": [
        "def find_max_min(current_df):\n",
        "  \"\"\"Return ``(min, max)`` taken over every sample of every signal in ``current_df``.\"\"\"\n",
        "  smallest_per_signal = current_df['signal'].map(min)\n",
        "  largest_per_signal = current_df['signal'].map(max)\n",
        "  return smallest_per_signal.min(), largest_per_signal.max()\n",
        "\n",
        "# Global extrema of each split, printed in the order train, validation, test.\n",
        "min_train_df, max_train_df = find_max_min(train_df)\n",
        "print(min_train_df, max_train_df)\n",
        "\n",
        "min_val_df, max_val_df = find_max_min(val_df)\n",
        "print(min_val_df, max_val_df)\n",
        "\n",
        "min_test_df, max_test_df = find_max_min(test_df)\n",
        "print(min_test_df, max_test_df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "xQUEFk7yfB9m",
      "metadata": {
        "id": "xQUEFk7yfB9m"
      },
      "outputs": [],
      "source": [
        "# Min-max scale every split with the TRAINING extrema only. Scaling the\n",
        "# validation / test signals with their own min and max would let their\n",
        "# statistics drive the preprocessing and put each split on a different\n",
        "# amplitude scale. Validation / test values may fall slightly outside [0, 1].\n",
        "train_range = max_train_df - min_train_df\n",
        "\n",
        "def scale_with_train_stats(signal):\n",
        "  \"\"\"Return ``signal`` as a list rescaled with the training min / max.\"\"\"\n",
        "  return [(item - min_train_df) / train_range for item in signal]\n",
        "\n",
        "train_df['signal'] = train_df['signal'].apply(scale_with_train_stats)\n",
        "\n",
        "val_df['signal'] = val_df['signal'].apply(scale_with_train_stats)\n",
        "\n",
        "test_df['signal'] = test_df['signal'].apply(scale_with_train_stats)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "rP54exdx7ZFF",
      "metadata": {
        "id": "rP54exdx7ZFF"
      },
      "source": [
        "# Quantization\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "T9cwSwdnUrl7",
      "metadata": {
        "id": "T9cwSwdnUrl7"
      },
      "outputs": [],
      "source": [
        "# from https://github.com/joaomrcarvalho/diffquantizer.git\n",
        "\n",
        "class DiffQuantizer:\n",
        "    \"\"\"Map a numeric signal to letters using ascending breakpoints.\n",
        "\n",
        "    A value is encoded as ``chr(65 + k)`` where ``k`` is the index of the first\n",
        "    breakpoint strictly greater than the value ('A' is the lowest bin).\n",
        "    Breakpoints are passed in or learned with ``learn_breakpoints``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, alphabet_size, average_over=1, filter=False, breakpoints=None, use_diffs=True):\n",
        "        \"\"\"Store the quantizer configuration.\n",
        "\n",
        "        alphabet_size: number of bins, i.e. number of distinct letters.\n",
        "        average_over: if not 1, average consecutive groups of this many samples first.\n",
        "        filter: if True, ``preprocess`` calls ``_filter_signal`` (not implemented here).\n",
        "        breakpoints: optional ascending breakpoints to start with.\n",
        "        use_diffs: if True, quantize first differences instead of raw values.\n",
        "        \"\"\"\n",
        "        self.alphabet_size = alphabet_size\n",
        "        self.average_over = average_over\n",
        "        self.use_filter = filter\n",
        "        self.breakpoints = breakpoints\n",
        "        self.use_diffs = use_diffs\n",
        "\n",
        "    def preprocess(self, tmp):\n",
        "        \"\"\"Apply the configured averaging, filtering and differencing, in that order.\"\"\"\n",
        "        if self.average_over != 1:\n",
        "            tmp = self._average_over_n(tmp, self.average_over)\n",
        "\n",
        "        if self.use_filter:\n",
        "            tmp = self._filter_signal(tmp)\n",
        "\n",
        "        if self.use_diffs:\n",
        "            tmp = self._diff_signal(tmp)\n",
        "\n",
        "        return tmp\n",
        "\n",
        "    def perform_quantization(self, tmp, breakpoints=None):\n",
        "        \"\"\"Preprocess ``tmp`` and return an array holding one letter per sample.\n",
        "\n",
        "        ``breakpoints`` replaces the stored breakpoints only when it is given, so\n",
        "        breakpoints learned earlier are kept when the argument is omitted.\n",
        "        Raises ValueError when no breakpoints are available.\n",
        "        \"\"\"\n",
        "        if breakpoints is not None:\n",
        "            self.breakpoints = breakpoints\n",
        "        if self.breakpoints is None or len(self.breakpoints) == 0:\n",
        "            raise ValueError('no breakpoints available: call learn_breakpoints() or pass breakpoints')\n",
        "        return self._quantize_with_breakpoints(self.preprocess(tmp))\n",
        "\n",
        "    def learn_breakpoints(self, arr):\n",
        "        \"\"\"Learn ``alphabet_size`` equal-frequency breakpoints from ``arr``.\n",
        "\n",
        "        The preprocessed values are flattened and sorted, and one value is picked\n",
        "        at each cumulative 1/alphabet_size share of the data. The last breakpoint\n",
        "        is replaced by 1e+100 so that every finite value falls below it.\n",
        "        The breakpoints are stored on the instance and returned.\n",
        "        Raises ValueError for empty input.\n",
        "        \"\"\"\n",
        "        res = self.preprocess(arr)\n",
        "\n",
        "        # float64 so the 1e+100 sentinel below can be stored even for integer input\n",
        "        sorted_array = np.sort(np.asarray(res, dtype=np.float64), axis=None)\n",
        "\n",
        "        length = len(sorted_array)\n",
        "        if length == 0:\n",
        "            raise ValueError('cannot learn breakpoints from an empty array')\n",
        "\n",
        "        probs = [1 / self.alphabet_size for _ in range(self.alphabet_size)]\n",
        "        cum_sum_breakpoints = [int(sum(probs[0:i + 1]) * length - 1) for i in range(len(probs))]\n",
        "        cum_sum_breakpoint_values = sorted_array[cum_sum_breakpoints]\n",
        "\n",
        "        cum_sum_breakpoint_values[-1] = 1e+100\n",
        "\n",
        "        self.breakpoints = cum_sum_breakpoint_values\n",
        "\n",
        "        return cum_sum_breakpoint_values\n",
        "\n",
        "    # scalar helper, kept for callers that quantize one value at a time\n",
        "    @staticmethod\n",
        "    def _breakpoint_to_letter(float_num, breakpoints):\n",
        "        \"\"\"Return the letter of the first breakpoint strictly greater than ``float_num``.\n",
        "\n",
        "        A value that is not below any breakpoint gets the last letter.\n",
        "        \"\"\"\n",
        "        for int_val, bound in enumerate(breakpoints):\n",
        "            if float_num < bound:\n",
        "                # A + int_val\n",
        "                return chr(65 + int_val)\n",
        "        return chr(65 + len(breakpoints) - 1)\n",
        "\n",
        "    def _quantize_with_breakpoints(self, tmp):\n",
        "        \"\"\"Vectorized letter lookup; assumes ``self.breakpoints`` is ascending.\"\"\"\n",
        "        bounds = np.asarray(self.breakpoints, dtype=np.float64)\n",
        "        values = np.asarray(tmp, dtype=np.float64)\n",
        "        # side='right' yields the index of the first breakpoint strictly greater\n",
        "        # than each value; anything beyond the last breakpoint uses the last bin\n",
        "        bins = np.minimum(np.searchsorted(bounds, values, side='right'), len(bounds) - 1)\n",
        "        letters = np.array([chr(65 + k) for k in range(len(bounds))])\n",
        "        return letters[bins]\n",
        "\n",
        "    @staticmethod\n",
        "    def _read_csv_file(input_file):\n",
        "        tmp_file_content = pd.read_csv(input_file, sep=\"\\n\", header=None, dtype=np.float64)\n",
        "        return np.array(tmp_file_content)\n",
        "\n",
        "    @staticmethod\n",
        "    def _filter_signal(tmp):\n",
        "        # The low-pass filter of the upstream project (butter_lowpass_filter) is\n",
        "        # not defined in this notebook, so fail with an explicit message.\n",
        "        raise NotImplementedError('filter=True is not supported: no filter implementation is available')\n",
        "\n",
        "    @staticmethod\n",
        "    def _average_over_n(tmp, n):\n",
        "        return np.array([np.average(tmp[i:i + n]) for i in range(0, len(tmp), n)])\n",
        "\n",
        "    @staticmethod\n",
        "    def _diff_signal(tmp):\n",
        "        # prepend 0.0 so the differenced signal keeps the original length\n",
        "        res = np.diff(tmp)\n",
        "        return np.insert(res, 0, 0.0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "FP5-zJfHbkdg",
      "metadata": {
        "id": "FP5-zJfHbkdg"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import csv\n",
        "\n",
        "class Preprocessing():\n",
        "  \"\"\"Cut one record's signal into windows and encode the windows as letter strings.\n",
        "\n",
        "  ``input_file`` is a single record that supports ``['signal']`` and ``['label']``\n",
        "  lookups (for example one DataFrame row).\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, input_file, discretizition_factor, max_window_size):\n",
        "    self.input_file = input_file\n",
        "    self.discretizition_factor = discretizition_factor\n",
        "    self.max_window_size = max_window_size\n",
        "\n",
        "\n",
        "  def create_window(self):\n",
        "    \"\"\"Return the list of complete, non-overlapping windows of the signal.\n",
        "\n",
        "    The window length is the signal length capped at ``max_window_size``.\n",
        "    Samples left over after the last complete window are dropped, and an\n",
        "    empty signal yields an empty list.\n",
        "    \"\"\"\n",
        "    ecg_list = self.input_file['signal']\n",
        "    signal_length = len(ecg_list)\n",
        "    window_size = min(signal_length, self.max_window_size)\n",
        "\n",
        "    # nothing to cut (also avoids dividing by a zero window size)\n",
        "    if window_size <= 0:\n",
        "      return []\n",
        "\n",
        "    num_lines = signal_length // window_size\n",
        "    return [ecg_list[i * window_size:(i + 1) * window_size] for i in range(num_lines)]\n",
        "\n",
        "  def change_to_alphabet(self, quantizer, normalized_list):\n",
        "    \"\"\"Quantize every window with ``quantizer`` and join its letters into one string.\n",
        "\n",
        "    Returns ``(qtz_signal, labels)``: one string per window and the record's\n",
        "    label repeated once per window.\n",
        "    \"\"\"\n",
        "    qtz_signal = [\n",
        "      ''.join(quantizer.perform_quantization(np.array(window), breakpoints=quantizer.breakpoints))\n",
        "      for window in normalized_list\n",
        "    ]\n",
        "    labels = [self.input_file['label']] * len(qtz_signal)\n",
        "\n",
        "    return qtz_signal, labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "OQTmc7tdvUeI",
      "metadata": {
        "id": "OQTmc7tdvUeI"
      },
      "outputs": [],
      "source": [
        "### lloyd_max\n",
        "def discretization_lloyd_max(discretizition_factor, total_data):\n",
        "  \"\"\"Return a DiffQuantizer whose breakpoints were learned from ``total_data``.\n",
        "\n",
        "  The quantizer has ``discretizition_factor`` letters and works on raw values\n",
        "  (``use_diffs=False``); the breakpoints come from ``learn_breakpoints``.\n",
        "  \"\"\"\n",
        "  fitted_quantizer = DiffQuantizer(alphabet_size=discretizition_factor, breakpoints=None, use_diffs=False)\n",
        "  fitted_quantizer.learn_breakpoints(np.array(total_data))\n",
        "  return fitted_quantizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Q_D6QlzN-lNN",
      "metadata": {
        "id": "Q_D6QlzN-lNN"
      },
      "outputs": [],
      "source": [
        "def run_Preprocessing(db, max_window_size=7500, discretizition_factor=100, quantizer=None):\n",
        "    \"\"\"Window every record of ``db`` and encode each window as a string of letters.\n",
        "\n",
        "    db: DataFrame with 'signal' and 'label' columns; rows are read by position.\n",
        "    max_window_size: upper bound on the window length handed to Preprocessing.\n",
        "    discretizition_factor: alphabet size of the quantizer.\n",
        "    quantizer: optional, already fitted quantizer. When omitted, one is fitted\n",
        "        on the first window of every record in ``db``. Passing the quantizer\n",
        "        fitted on the training split lets other splits share its breakpoints.\n",
        "\n",
        "    Returns ``(r_list, labels)``: the encoded windows and one label per window.\n",
        "    An empty ``db`` returns two empty lists.\n",
        "    \"\"\"\n",
        "    if len(db) == 0:\n",
        "        return [], []\n",
        "\n",
        "    # Window each record once and reuse the result for fitting and encoding.\n",
        "    # iloc reads rows by position, so the index does not have to be 0..n-1.\n",
        "    records = [Preprocessing(db.iloc[n], discretizition_factor, max_window_size) for n in range(len(db))]\n",
        "    windows_per_record = [pre.create_window() for pre in records]\n",
        "\n",
        "    if quantizer is None:\n",
        "        # breakpoints are learned from the first window of each record\n",
        "        total_data = []\n",
        "        for windows in windows_per_record:\n",
        "            if windows:\n",
        "                total_data.extend(windows[0])\n",
        "        quantizer = discretization_lloyd_max(discretizition_factor, total_data)\n",
        "\n",
        "    r_list, labels = [], []\n",
        "    for pre, windows in zip(records, windows_per_record):\n",
        "        r, l = pre.change_to_alphabet(quantizer, windows)\n",
        "        r_list.extend(r)\n",
        "        labels.extend(l)\n",
        "    print(len(db) - 1, \" : done\")\n",
        "\n",
        "    return r_list, labels"
      ]
    },
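    {
      "cell_type": "markdown",
      "id": "Qaleio2RaHuG",
      "metadata": {
        "id": "Qaleio2RaHuG"
      },
      "source": [
        "**Check `max_window_size`:** each signal is cut into complete windows of at most `max_window_size` samples, and any samples left over after the last complete window are dropped."
      ]
    },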
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8kfdJIhbOgzE",
      "metadata": {
        "id": "8kfdJIhbOgzE"
      },
      "outputs": [],
      "source": [
        "# The same window length and alphabet size are used for every split.\n",
        "preprocessing_options = dict(max_window_size=1080, discretizition_factor=100)\n",
        "\n",
        "r_list_train, labels_train = run_Preprocessing(db=train_df, **preprocessing_options)\n",
        "\n",
        "r_list_val, labels_val = run_Preprocessing(db=val_df, **preprocessing_options)\n",
        "\n",
        "r_list_test, labels_test = run_Preprocessing(db=test_df, **preprocessing_options)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "P_EiiYPF0Yec",
      "metadata": {
        "id": "P_EiiYPF0Yec"
      },
      "source": [
        "# Tokenizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fT5QosR96oJe",
      "metadata": {
        "id": "fT5QosR96oJe"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, AutoModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4nKlFTVqtMNr",
      "metadata": {
        "id": "4nKlFTVqtMNr"
      },
      "outputs": [],
      "source": [
        "tokenizer_ft = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=\"/path...\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Waf_u_mS6m2n",
      "metadata": {
        "id": "Waf_u_mS6m2n"
      },
      "outputs": [],
      "source": [
        "def tokenize_function(examples, max_length = 512):\n",
        "    \"\"\"Tokenize ``examples`` with ``tokenizer_ft``, padded and truncated to ``max_length``.\"\"\"\n",
        "    encoding_options = {\"padding\": \"max_length\", \"truncation\": True, \"max_length\": max_length}\n",
        "    return tokenizer_ft(examples, **encoding_options)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "wVCCbiBY6m2n",
      "metadata": {
        "id": "wVCCbiBY6m2n"
      },
      "outputs": [],
      "source": [
        "# One encoding per quantized string, in the same order as the input lists.\n",
        "tokenized_dataset_train = [tokenize_function(sequence) for sequence in r_list_train]\n",
        "\n",
        "tokenized_dataset_val = [tokenize_function(sequence) for sequence in r_list_val]\n",
        "\n",
        "tokenized_dataset_test = [tokenize_function(sequence) for sequence in r_list_test]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "niQm13G2qEfe",
      "metadata": {
        "id": "niQm13G2qEfe"
      },
      "source": [
        "## input_ids, attention_masks, labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "lGGbYR7spHjt",
      "metadata": {
        "id": "lGGbYR7spHjt"
      },
      "outputs": [],
      "source": [
        "def return_ids(tokenized_dataset):\n",
        "  \"\"\"Stack the ``input_ids`` and ``attention_mask`` of every encoding.\n",
        "\n",
        "  tokenized_dataset: sequence of mappings with 'input_ids' and 'attention_mask'.\n",
        "\n",
        "  Returns ``(input_ids, attention_masks)``: an int32 array and a bool array\n",
        "  with one row per encoding. The encodings are only read, never modified, so\n",
        "  the function can be called again on the same data and does not require a\n",
        "  'token_type_ids' entry.\n",
        "  \"\"\"\n",
        "  input_ids = np.array(\n",
        "      [np.asarray(encoding['input_ids'], dtype=np.int32) for encoding in tokenized_dataset],\n",
        "      dtype=np.int32,\n",
        "  )\n",
        "  attention_masks = np.array(\n",
        "      [np.asarray(encoding['attention_mask'], dtype=bool) for encoding in tokenized_dataset],\n",
        "      dtype=bool,\n",
        "  )\n",
        "\n",
        "  return input_ids, attention_masks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "NZFMyJCntiwk",
      "metadata": {
        "id": "NZFMyJCntiwk"
      },
      "outputs": [],
      "source": [
        "input_ids_train, attention_masks_train = return_ids(tokenized_dataset_train)\n",
        "\n",
        "input_ids_val, attention_masks_val = return_ids(tokenized_dataset_val)\n",
        "\n",
        "input_ids_test, attention_masks_test = return_ids(tokenized_dataset_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "OG17Rq5tuq0K",
      "metadata": {
        "id": "OG17Rq5tuq0K"
      },
      "outputs": [],
      "source": [
        "labels_train = np.array(labels_train,dtype=np.int8)\n",
        "\n",
        "labels_val = np.array(labels_val,dtype=np.int8)\n",
        "\n",
        "labels_test = np.array(labels_test,dtype=np.int8)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bEppEuYow1Wy",
      "metadata": {
        "id": "bEppEuYow1Wy"
      },
      "source": [
        "# DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ajDUtGJYPNws",
      "metadata": {
        "id": "ajDUtGJYPNws"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import TensorDataset, DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eQq23H8MO_6q",
      "metadata": {
        "id": "eQq23H8MO_6q"
      },
      "outputs": [],
      "source": [
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "j_ovU5k_FQvH",
      "metadata": {
        "id": "j_ovU5k_FQvH"
      },
      "outputs": [],
      "source": [
        "bert_model = AutoModel.from_pretrained(\"/path...\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SvpgO2WaO06I",
      "metadata": {
        "id": "SvpgO2WaO06I"
      },
      "outputs": [],
      "source": [
        "bert_model = bert_model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "n133rpUCyZ-B",
      "metadata": {
        "id": "n133rpUCyZ-B"
      },
      "outputs": [],
      "source": [
        "bert_model.eval()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Kr2lzUUCL3Is",
      "metadata": {
        "id": "Kr2lzUUCL3Is"
      },
      "outputs": [],
      "source": [
        "for param in bert_model.parameters():\n",
        "  param.requires_grad = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "-8-Cp5p8xUbE",
      "metadata": {
        "id": "-8-Cp5p8xUbE"
      },
      "outputs": [],
      "source": [
        "batch_size = 8"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "LR64qTZ-f8YP",
      "metadata": {
        "id": "LR64qTZ-f8YP"
      },
      "outputs": [],
      "source": [
        "dataset_ = TensorDataset(torch.tensor(input_ids_train), torch.tensor(labels_train),\n",
        "                         torch.tensor(attention_masks_train))\n",
        "dataloader = DataLoader(dataset_, batch_size=batch_size, shuffle = True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "p_EFjQh5gS9F",
      "metadata": {
        "id": "p_EFjQh5gS9F"
      },
      "outputs": [],
      "source": [
        "dataset_valid = TensorDataset(torch.tensor(input_ids_val), torch.tensor(labels_val),\n",
        "                          torch.tensor(attention_masks_val))\n",
        "dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle = True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "aKObaAVmgUql",
      "metadata": {
        "id": "aKObaAVmgUql"
      },
      "outputs": [],
      "source": [
        "dataset_test = TensorDataset(torch.tensor(input_ids_test), torch.tensor(labels_test),\n",
        "                          torch.tensor(attention_masks_test))\n",
        "dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle = False)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "HOdVKiqgW_KC",
      "metadata": {
        "id": "HOdVKiqgW_KC"
      },
      "source": [
        "# BertBiLSTMClassifier"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bNPHcnS-7_75",
      "metadata": {
        "id": "bNPHcnS-7_75"
      },
      "outputs": [],
      "source": [
        "class BertBiLSTMClassifier(nn.Module):\n",
        "    \"\"\"Pretrained encoder followed by a Bi-LSTM and a linear classification head.\n",
        "\n",
        "    num_classes: size of the output layer.\n",
        "    bert_frozen_layers: number of leading encoder layers frozen together with\n",
        "        the embeddings.\n",
        "    input_size: feature size fed to the LSTM.\n",
        "    lstm_hidden_size: hidden size of each LSTM direction.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_classes, bert_frozen_layers=3, input_size=768, lstm_hidden_size=128):\n",
        "        # you can set bert_frozen_layers to an arbitrary number\n",
        "        super(BertBiLSTMClassifier, self).__init__()\n",
        "\n",
        "        # HeartBERT model with frozen layers\n",
        "        self.bert = AutoModel.from_pretrained(pretrained_model_name_or_path=\"/path...\")\n",
        "\n",
        "        frozen_modules = [self.bert.embeddings, *self.bert.encoder.layer[:bert_frozen_layers]]\n",
        "        for module in frozen_modules:\n",
        "          for param in module.parameters():\n",
        "            param.requires_grad = False\n",
        "        # Bi-LSTM layer\n",
        "        self.lstm = nn.LSTM(input_size=input_size, hidden_size=lstm_hidden_size, bidirectional=True, batch_first=True)\n",
        "\n",
        "        # Fully connected layer for classification\n",
        "        self.fc = nn.Linear(lstm_hidden_size * 2, num_classes)\n",
        "\n",
        "    def forward(self, input_ids, attention_mask=None):\n",
        "        \"\"\"Return the raw class scores (no softmax) for a batch of token ids.\n",
        "\n",
        "        attention_mask: optional mask with 1 / True on real tokens. When it is\n",
        "            omitted the original behaviour is kept: the encoder sees every\n",
        "            position and the LSTM output at the last position is classified,\n",
        "            which is a padding position for padded inputs. When it is given,\n",
        "            the encoder ignores padding and the LSTM runs on the real tokens\n",
        "            only; the final hidden states of both directions are classified.\n",
        "        \"\"\"\n",
        "        if attention_mask is None:\n",
        "            bert_output = self.bert(input_ids)[0]\n",
        "            lstm_out, _ = self.lstm(bert_output)\n",
        "            features = lstm_out[:, -1, :]\n",
        "        else:\n",
        "            mask = attention_mask.long()\n",
        "            bert_output = self.bert(input_ids, attention_mask=mask)[0]\n",
        "            # assumes padding is on the right, so the real tokens are the\n",
        "            # first `length` positions of each row -- TODO confirm for this tokenizer\n",
        "            lengths = mask.sum(dim=1).clamp(min=1).cpu()\n",
        "            packed = nn.utils.rnn.pack_padded_sequence(\n",
        "                bert_output, lengths, batch_first=True, enforce_sorted=False\n",
        "            )\n",
        "            _, (h_n, _) = self.lstm(packed)\n",
        "            # h_n[-2] / h_n[-1]: final forward / backward hidden state\n",
        "            features = torch.cat((h_n[-2], h_n[-1]), dim=1)\n",
        "\n",
        "        # Classification using fully connected layer\n",
        "        return self.fc(features)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0fmIo3rS9HBn",
      "metadata": {
        "id": "0fmIo3rS9HBn"
      },
      "outputs": [],
      "source": [
        "model = BertBiLSTMClassifier(4)\n",
        "model = model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ZioFSxiHsHXS",
      "metadata": {
        "id": "ZioFSxiHsHXS"
      },
      "outputs": [],
      "source": [
        "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "trainable_params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3NeBRPEiPTfA",
      "metadata": {
        "id": "3NeBRPEiPTfA"
      },
      "outputs": [],
      "source": [
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.Adam(model.parameters(), lr=3e-4) # setup the proper lr"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "WoldsV_pGJ_I",
      "metadata": {
        "id": "WoldsV_pGJ_I"
      },
      "source": [
        "# Metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "tclmuyY8aofz",
      "metadata": {
        "id": "tclmuyY8aofz"
      },
      "outputs": [],
      "source": [
        "from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def compute_metrics(eval_pred, labels):\n",
        "    \"\"\"Print per-class, micro- and macro-averaged precision / recall / F1 and the accuracy.\n",
        "\n",
        "    eval_pred: per-class scores; the predicted class is the argmax over the last axis.\n",
        "    labels: true class ids.\n",
        "    Classes are printed with 1-based numbers.\n",
        "    \"\"\"\n",
        "    predictions = np.argmax(eval_pred, axis=-1)\n",
        "\n",
        "    # Compute precision, recall, F1, and accuracy\n",
        "    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average=None)\n",
        "    accuracy = accuracy_score(labels, predictions)\n",
        "\n",
        "    # Compute micro and macro averages\n",
        "    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(labels, predictions, average='micro')\n",
        "    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(labels, predictions, average='macro')\n",
        "\n",
        "    print(\"Class-wise Metrics:\")\n",
        "    for class_number, (p, r, f) in enumerate(zip(precision, recall, f1), start=1):\n",
        "        print(f\"Class {class_number}: Precision={p:.4f}, Recall={r:.4f}, F1={f:.4f}\")\n",
        "\n",
        "    print(\"\\nMicro Average Metrics:\")\n",
        "    print(f\"Precision={micro_precision:.4f}, Recall={micro_recall:.4f}, F1={micro_f1:.4f}\")\n",
        "\n",
        "    print(\"\\nMacro Average Metrics:\")\n",
        "    print(f\"Precision={macro_precision:.4f}, Recall={macro_recall:.4f}, F1={macro_f1:.4f}\")\n",
        "\n",
        "    print(\"\\nAccuracy:\")\n",
        "    print(f\"Accuracy={accuracy:.4f}\")\n",
        "\n",
        "def display_confusion_matrix(eval_pred, labels):\n",
        "    \"\"\"Plot the confusion matrix of the argmax predictions against ``labels``.\n",
        "\n",
        "    The number of classes is the size of the last axis of ``eval_pred``. The\n",
        "    matrix always covers class ids 0 .. num_classes-1, even when a class does\n",
        "    not occur, so it matches the 1-based tick labels.\n",
        "    \"\"\"\n",
        "    num_classes = eval_pred.shape[-1]\n",
        "\n",
        "    predictions = np.argmax(eval_pred, axis=-1)\n",
        "\n",
        "    cm = confusion_matrix(labels, predictions, labels=np.arange(num_classes))\n",
        "\n",
        "    # Plot confusion matrix\n",
        "    plt.figure(figsize=(num_classes, num_classes))\n",
        "\n",
        "    tick_labels = np.arange(1, num_classes + 1)\n",
        "    sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", cbar=False,\n",
        "                xticklabels=tick_labels,\n",
        "                yticklabels=tick_labels)\n",
        "\n",
        "    plt.xlabel('Predicted')\n",
        "    plt.ylabel('Actual')\n",
        "    plt.title('Confusion Matrix')\n",
        "    plt.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5JFeAu-0n1g3",
      "metadata": {
        "id": "5JFeAu-0n1g3"
      },
      "source": [
        "# Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sjov9JAXPeEA",
      "metadata": {
        "id": "sjov9JAXPeEA"
      },
      "outputs": [],
      "source": [
        "loss_epoch_train = []\n",
        "loss_epoch_valid = []\n",
        "num_epochs = 50"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GlBLq8OhkAuC",
      "metadata": {
        "id": "GlBLq8OhkAuC"
      },
      "outputs": [],
      "source": [
        "from numpy import mean\n",
        "# Per-epoch accuracy; losses are appended to loss_epoch_train / loss_epoch_valid.\n",
        "total_train_acc = np.zeros((num_epochs,))\n",
        "total_valid_acc = np.zeros((num_epochs,))\n",
        "for epoch in range(num_epochs):\n",
        "  loss_batch_train = []\n",
        "  correct_train = []  # 1.0 / 0.0 per training sample of this epoch\n",
        "  correct_valid = []  # 1.0 / 0.0 per validation sample of this epoch\n",
        "  model.train()\n",
        "  for batch in dataloader:\n",
        "\n",
        "    inputs, labels, _ = batch  # the attention mask is not passed to the model\n",
        "\n",
        "    inputs = inputs.to(device)\n",
        "    # CrossEntropyLoss expects int64 class ids; cast before moving to the device\n",
        "    labels = labels.long().to(device)\n",
        "    optimizer.zero_grad()\n",
        "    outputs = model(inputs)\n",
        "    loss = criterion(outputs, labels)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    predicted = torch.argmax(outputs.detach(), dim=1)\n",
        "    # store plain Python floats instead of 0-dim device tensors\n",
        "    correct_train.extend(predicted.eq(labels).float().tolist())\n",
        "    loss_batch_train.append(loss.item())\n",
        "\n",
        "  # if (epoch%10==0):\n",
        "  #   torch.save(model.state_dict(), '/content/drive/MyDrive/'+str(epoch)+'.pt')\n",
        "\n",
        "  total_train_acc[epoch] = mean(correct_train)\n",
        "  loss_train_per_epoch = mean(loss_batch_train)\n",
        "  loss_epoch_train.append(loss_train_per_epoch)\n",
        "  print('train-loss',epoch,':', loss_epoch_train[-1])\n",
        "\n",
        "  model.eval()\n",
        "  loss_batch_valid = []\n",
        "\n",
        "  # no gradients are needed for validation, so do not build the autograd graph\n",
        "  with torch.no_grad():\n",
        "    for batch in dataloader_valid:\n",
        "      inputs, labels, _ = batch\n",
        "      inputs = inputs.to(device)\n",
        "      labels = labels.long().to(device)\n",
        "      outputs = model(inputs)\n",
        "      loss = criterion(outputs, labels)\n",
        "      predicted = torch.argmax(outputs, dim=1)\n",
        "      correct_valid.extend(predicted.eq(labels).float().tolist())\n",
        "      loss_batch_valid.append(loss.item())\n",
        "\n",
        "  total_valid_acc[epoch] = mean(correct_valid)\n",
        "  loss_valid_per_epoch = mean(loss_batch_valid)\n",
        "  loss_epoch_valid.append(loss_valid_per_epoch)\n",
        "  print(total_train_acc,total_valid_acc)\n",
        "  print('valid-loss',epoch,':',loss_epoch_valid[-1])\n",
        "  print('**********')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fNrT4XQJ0pTJ",
      "metadata": {
        "id": "fNrT4XQJ0pTJ"
      },
      "outputs": [],
      "source": [
        "print(total_train_acc,total_valid_acc)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "VX8bpLvQztZ1",
      "metadata": {
        "id": "VX8bpLvQztZ1"
      },
      "source": [
        "## Plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "pXNF9gthzs6k",
      "metadata": {
        "id": "pXNF9gthzs6k"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Training loss in red, validation loss in blue, one point per epoch.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(loss_epoch_train, 'r')\n",
        "ax.plot(loss_epoch_valid, 'b')\n",
        "ax.set_title('train-loss vs eval-loss')\n",
        "ax.set_xlabel('epoch')\n",
        "ax.set_ylabel('loss')\n",
        "ax.legend(['loss_train', 'loss_eval']);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "EzoTirhWLcBN",
      "metadata": {
        "id": "EzoTirhWLcBN"
      },
      "outputs": [],
      "source": [
        "# Accuracy per epoch; labelled so the figure can be read on its own.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(total_train_acc, label='train accuracy')\n",
        "ax.plot(total_valid_acc, label='validation accuracy')\n",
        "ax.set_title('train vs validation accuracy')\n",
        "ax.set_xlabel('epoch')\n",
        "ax.set_ylabel('accuracy')\n",
        "ax.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "GwYU7iwav0p1",
      "metadata": {
        "id": "GwYU7iwav0p1"
      },
      "source": [
        "# Test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "DvRZWk0DpxSq",
      "metadata": {
        "id": "DvRZWk0DpxSq"
      },
      "outputs": [],
      "source": [
        "model.eval()\n",
        "loss_batch_test = []\n",
        "\n",
        "# Per-batch results are collected in lists and concatenated once, so this cell\n",
        "# does not depend on variables left over from the training loop.\n",
        "test_logits = []\n",
        "test_labels = []\n",
        "\n",
        "with torch.no_grad():\n",
        "  for batch in dataloader_test:\n",
        "    inputs, labels, _ = batch\n",
        "    inputs = inputs.to(device)\n",
        "    # CrossEntropyLoss expects int64 class ids\n",
        "    labels = labels.long().to(device)\n",
        "    outputs = model(inputs)\n",
        "\n",
        "    loss = criterion(outputs, labels)\n",
        "    loss_batch_test.append(loss.item())\n",
        "\n",
        "    test_logits.append(outputs.cpu())\n",
        "    test_labels.append(labels.cpu())\n",
        "\n",
        "# one row of class scores / one integer label per test sample\n",
        "outputs_temp = torch.cat(test_logits, dim=0)\n",
        "labels_temp = torch.cat(test_labels, dim=0)\n",
        "\n",
        "print('test-loss: ', np.mean(loss_batch_test))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "XU68XuyKLXjo",
      "metadata": {
        "id": "XU68XuyKLXjo"
      },
      "outputs": [],
      "source": [
        "compute_metrics(outputs_temp.numpy(), labels_temp.numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "-5cG7fZxp4BZ",
      "metadata": {
        "id": "-5cG7fZxp4BZ"
      },
      "outputs": [],
      "source": [
        "display_confusion_matrix(outputs_temp.numpy(), labels_temp.numpy())"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}